#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Aug  4 12:53:22 2021

@author: Burak
"""
"""
generate results for shifted mean, increased variance and perfect intervention.

run for our algorithm: generate results for I recovery
run for our algorithm + UT-IGSP: generate results for I recovery (also I_parent recovery)

Use Gauss CI tests for UT-IGSP

"""

import numpy as np
import pickle
from config import SIMULATIONS_ESTIMATED_FOLDER
from functions import algorithm_sample
from functions_utigsp import run_utigsp
from helpers import create_intervention, sample, counter

#%%
def run_ours_repeated(p_list, density_list, n_samples_list, I_size, n_repeat,
                      shift=0.0, plus_variance=0.0, B_distortion_amplitude=0,
                      perfect_intervention=False, rho=1, lambda_l1=0.2,
                      single_threshold=0.1, pair_l1=0.1, pair_threshold=5e-3,
                      parent_l1=0.1, verbose=True):
    """Run our algorithm repeatedly over a grid of simulation settings.

    For every repetition, node count ``p`` and density, one model pair is
    drawn with ``create_intervention``. That pair is reused for every sample
    size, so only ``n`` varies along the last grid axis. For each sample
    size, fresh samples are drawn from both models and ``algorithm_sample``
    is run on their second-moment matrices ``X.T @ X / n``. The estimate is
    then scored against the ground truth with ``counter``.

    Parameters
    ----------
    p_list, density_list, n_samples_list : sequences
        Grid values for the number of nodes, graph density and sample size.
    I_size : int
        Size argument forwarded to ``create_intervention``.
    n_repeat : int
        Number of independent repetitions.
    shift, plus_variance, B_distortion_amplitude, perfect_intervention
        Intervention settings forwarded to ``create_intervention``.
    rho, lambda_l1, single_threshold, pair_l1, pair_threshold, parent_l1
        Hyperparameters forwarded to ``algorithm_sample``.
    verbose : bool, optional
        If True (the default), print the grid index ``i j k s`` after each run.

    Returns
    -------
    dict
        The grid settings plus arrays of shape
        ``(n_repeat, len(p_list), len(density_list), len(n_samples_list))``:

        - ``I_tp``/``I_fp``/``I_fn`` and ``e_tp``/``e_fp``/``e_fn``: the six
          counts returned by ``counter``.
        - ``time``: the ``t_past`` value returned by ``algorithm_sample``.
    """
    shape = (n_repeat, len(p_list), len(density_list), len(n_samples_list))
    I_tp, I_fp, I_fn = np.zeros(shape), np.zeros(shape), np.zeros(shape)
    e_tp, e_fp, e_fn = np.zeros(shape), np.zeros(shape), np.zeros(shape)
    time_ours = np.zeros(shape)

    for i in range(n_repeat):
        for j, p in enumerate(p_list):
            for k, density in enumerate(density_list):
                # One ground-truth model pair per (repeat, p, density),
                # shared by every sample size below.
                (B1, G1, mu1, variance1, Omega1, Theta1, Cov1,
                 B2, G2, mu2, variance2, Omega2, Theta2, Cov2,
                 Delta_Theta, S_Delta, I) = create_intervention(
                    p, I_size, density, mu=0, shift=shift,
                    plus_variance=plus_variance, variance=1.0,
                    B_distortion_amplitude=B_distortion_amplitude,
                    perfect_intervention=perfect_intervention)

                # For each node in I: row indices of the nonzero entries in
                # its column of B1 (used as its ground-truth parents).
                # `node` avoids shadowing the repeat index `i`.
                I_parents = [np.where(B1[:, node])[0].tolist() for node in I]

                for s, n in enumerate(n_samples_list):
                    X1 = sample(B1, mu1, variance1, n)
                    X2 = sample(B2, mu2, variance2, n)
                    # No mean-centering is applied here.
                    S1 = (X1.T @ X1) / n
                    S2 = (X2.T @ X2) / n

                    I_hat, I_hat_parents, N_lists, A_groups, t_past = algorithm_sample(
                        S1, S2, lambda_l1, rho, single_threshold, pair_l1,
                        pair_threshold, parent_l1, return_parents=True,
                        verbose=False, Delta_hat_parent_check=True)

                    tp_i, fp_i, fn_i, tp_e, fp_e, fn_e = counter(
                        I, I_hat, I_parents, I_hat_parents)
                    idx = (i, j, k, s)
                    I_tp[idx], I_fp[idx], I_fn[idx] = tp_i, fp_i, fn_i
                    e_tp[idx], e_fp[idx], e_fn[idx] = tp_e, fp_e, fn_e
                    time_ours[idx] = t_past
                    if verbose:
                        print(i, j, k, s)

    res = {'n_repeat': n_repeat, 'p_list': p_list,
           'density_list': density_list, 'I_size': I_size,
           'n_samples_list': n_samples_list,
           'I_tp': I_tp, 'I_fp': I_fp, 'I_fn': I_fn,
           'e_tp': e_tp, 'e_fp': e_fp, 'e_fn': e_fn,
           'time': time_ours}

    return res


    
#%%

# --- Algorithm hyperparameters -------------------------------------------
rho = 1.0
lambda_l1 = 0.2          # used for S_Delta estimation and for pruning
single_threshold = 0.1   # used for J0 estimation
pair_l1 = 0.1            # used for J0_k estimation
pair_threshold = 5e-3    # J0_k estimation: throw away very small entries
parent_l1 = 0.1          # used for post-parent estimation
# NOTE(review): the four settings below are not passed to run_ours_repeated
# in the code visible here; kept for reference.
n_max_iter = 500
stop_cond = 1e-6
verbose = False
tol = 1e-9

# --- Simulation grid -------------------------------------------------------
n_repeat = 10
p_list = [100, 250, 500, 1000]
density_list = [2.5]
I_size = 5
n_samples_list = [5000]

# Run ours on the increased-variance intervention
# (plus_variance=1.0, no mean shift, no B distortion, not perfect).
res_inc = run_ours_repeated(
    p_list, density_list, n_samples_list, I_size, n_repeat,
    shift=0.0,
    plus_variance=1.0,
    B_distortion_amplitude=0.0,
    perfect_intervention=False,
    rho=rho,
    lambda_l1=lambda_l1,
    single_threshold=single_threshold,
    pair_l1=pair_l1,
    pair_threshold=pair_threshold,
    parent_l1=parent_l1,
)

    
#f = open(SIMULATIONS_ESTIMATED_FOLDER+'/increased_variance_1.pkl','wb')
#pickle.dump(res_inc,f)
#f.close()
    
#I_precision, I_recall, I_f1, e_precision, e_recall, e_f1 = scores(I_tp,I_fp,I_fn,e_tp,e_fp,e_fn)

#%%
# Run ours on the shifted-mean intervention
# (shift=1.0, no added variance, no B distortion, not perfect).
res_shift = run_ours_repeated(
    p_list, density_list, n_samples_list, I_size, n_repeat,
    shift=1.0,
    plus_variance=0.0,
    B_distortion_amplitude=0.0,
    perfect_intervention=False,
    rho=rho,
    lambda_l1=lambda_l1,
    single_threshold=single_threshold,
    pair_l1=pair_l1,
    pair_threshold=pair_threshold,
    parent_l1=parent_l1,
)

    
#f = open(SIMULATIONS_ESTIMATED_FOLDER+'/shifted_mean_1.pkl','wb')
#pickle.dump(res_shift,f)
#f.close()

        

